{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cifar10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "\n",
    "# This line allows mpl to run with no DISPLAY defined\n",
    "mpl.use('Agg')\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "from keras.layers import Reshape, Flatten, LeakyReLU, Activation, Dense\n",
    "from keras.layers.convolutional import Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D\n",
    "from keras.models import Sequential\n",
    "from keras.optimizers import Adam\n",
    "from keras.callbacks import TensorBoard\n",
    "from keras_adversarial.image_grid_callback import ImageGridCallback\n",
    "\n",
    "from keras_adversarial import AdversarialModel, simple_gan, gan_targets\n",
    "from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling\n",
    "from keras_adversarial.legacy import BatchNormalization, fit, l1l2\n",
    "import keras.backend as K\n",
    "from cifar10_utils import cifar10_data\n",
    "from image_utils import dim_ordering_fix, dim_ordering_unfix, dim_ordering_shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# model generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def model_generator():\n",
    "    \"\"\"Build the generator network: latent vector in, 32x32 3-channel image out.\n",
    "\n",
    "    A dense layer expands the 100-dimensional latent input to ``nch * 4 * 4``\n",
    "    features, which are reshaped into a 4x4 feature map and passed through\n",
    "    three Conv2D -> BatchNormalization -> LeakyReLU -> UpSampling2D stages\n",
    "    (4 -> 8 -> 16 -> 32). A final 3-filter convolution followed by a sigmoid\n",
    "    keeps every output value in [0, 1].\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled ``Sequential`` model.\n",
    "    \"\"\"\n",
    "    nch = 256  # number of feature maps right after the dense layer\n",
    "    h = 5      # side length of every convolution kernel\n",
    "\n",
    "    def reg():\n",
    "        # Each layer gets its own regularizer instance.\n",
    "        return l1l2(l1=1e-7, l2=1e-7)\n",
    "\n",
    "    # Batch normalisation has to run over the channel axis. A hard-coded\n",
    "    # axis=1 is the channel axis only for channels-first tensors; with\n",
    "    # channels-last data it would normalise over image rows instead.\n",
    "    bn_axis = 1 if K.image_data_format() == 'channels_first' else -1\n",
    "\n",
    "    model = Sequential()\n",
    "    # kernel_regularizer is the Keras 2 name of the old W_regularizer argument.\n",
    "    model.add(Dense(nch * 4 * 4, input_dim=100, kernel_regularizer=reg()))\n",
    "    model.add(BatchNormalization(mode=0))\n",
    "    # NOTE(review): dim_ordering_shape presumably reorders (channels, rows, cols)\n",
    "    # into the backend's image layout - confirm against image_utils.\n",
    "    model.add(Reshape(dim_ordering_shape((nch, 4, 4))))\n",
    "\n",
    "    # Three identical stages that differ only in filter count; each one\n",
    "    # doubles the spatial resolution at its end.\n",
    "    for filters in (nch // 2, nch // 2, nch // 4):\n",
    "        model.add(Conv2D(filters, (h, h), padding='same', kernel_regularizer=reg()))\n",
    "        model.add(BatchNormalization(mode=0, axis=bn_axis))\n",
    "        model.add(LeakyReLU(0.2))\n",
    "        model.add(UpSampling2D(size=(2, 2)))\n",
    "\n",
    "    # Project to 3 output channels and squash into [0, 1].\n",
    "    model.add(Conv2D(3, (h, h), padding='same', kernel_regularizer=reg()))\n",
    "    model.add(Activation('sigmoid'))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# model discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def model_discriminator():\n",
    "    \"\"\"Build the discriminator: a (32, 32, 3) image in, one sigmoid score out.\n",
    "\n",
    "    Three Conv2D -> MaxPooling2D -> LeakyReLU stages halve the resolution each\n",
    "    time (32 -> 16 -> 8 -> 4). A single-filter convolution is then averaged\n",
    "    over the remaining 4x4 map, flattened and passed through a sigmoid.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled ``Sequential`` model.\n",
    "    \"\"\"\n",
    "    base_filters = 256\n",
    "    kernel = (5, 5)\n",
    "\n",
    "    def regularizer():\n",
    "        # Each convolution gets its own regularizer instance.\n",
    "        return l1l2(l1=1e-7, l2=1e-7)\n",
    "\n",
    "    def conv(filters, **extra):\n",
    "        # Shared convolution settings; `extra` carries input_shape for the first layer.\n",
    "        return Conv2D(filters, kernel, padding='same', kernel_regularizer=regularizer(), **extra)\n",
    "\n",
    "    # All convolutions are created up front, in network order.\n",
    "    pooled_convs = [\n",
    "        conv(int(base_filters / 4), input_shape=(32, 32, 3)),\n",
    "        conv(int(base_filters / 2)),\n",
    "        conv(base_filters),\n",
    "    ]\n",
    "    score_conv = conv(1)\n",
    "\n",
    "    net = Sequential()\n",
    "    for layer in pooled_convs:\n",
    "        net.add(layer)\n",
    "        net.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "        net.add(LeakyReLU(0.2))\n",
    "\n",
    "    # Collapse the remaining 4x4 single-channel map into one probability.\n",
    "    net.add(score_conv)\n",
    "    net.add(AveragePooling2D(pool_size=(4, 4), padding='valid'))\n",
    "    net.add(Flatten())\n",
    "    net.add(Activation('sigmoid'))\n",
    "    return net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def example_gan(adversarial_optimizer, path, opt_g, opt_d, nb_epoch, generator, discriminator, latent_dim,\n",
    "                targets=gan_targets, loss='binary_crossentropy'):\n",
    "    \"\"\"Train ``generator`` against ``discriminator`` and save the results under ``path``.\n",
    "\n",
    "    Training is skipped when ``<path>/history.csv`` already exists. Otherwise\n",
    "    the data returned by ``cifar10_data()`` is used for training/validation and\n",
    "    the following is written under ``path``: sample grids produced by\n",
    "    ``ImageGridCallback`` (``epoch-NNN.png``), TensorBoard logs (TensorFlow\n",
    "    backend only), ``history.csv``, ``generator.h5`` and ``descriminator.h5``.\n",
    "\n",
    "    Args:\n",
    "        adversarial_optimizer: Strategy object handed to ``adversarial_compile``.\n",
    "        path: Output directory; created if it does not exist.\n",
    "        opt_g: Optimizer for the generator player.\n",
    "        opt_d: Optimizer for the discriminator player.\n",
    "        nb_epoch: Number of training epochs.\n",
    "        generator: Model mapping ``latent_dim``-sized noise vectors to images.\n",
    "        discriminator: Model scoring images.\n",
    "        latent_dim: Size of the noise vector fed to the generator.\n",
    "        targets: Callable that builds the target arrays from a sample count.\n",
    "        loss: Loss handed to ``adversarial_compile``.\n",
    "\n",
    "    Returns:\n",
    "        None.\n",
    "    \"\"\"\n",
    "    csvpath = os.path.join(path, \"history.csv\")\n",
    "    if os.path.exists(csvpath):\n",
    "        print('Already exists: {}'.format(csvpath))\n",
    "        return\n",
    "\n",
    "    print(\"Training: {}\".format(csvpath))\n",
    "    # Create the output directory before the (long) training run so the final\n",
    "    # CSV / model saves cannot fail on a missing folder.\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "\n",
    "    generator.summary()\n",
    "    discriminator.summary()\n",
    "    gan = simple_gan(generator=generator,\n",
    "                     discriminator=discriminator,\n",
    "                     latent_sampling=normal_latent_sampling((latent_dim,)))\n",
    "    # The player names become part of the history keys; the existing spelling\n",
    "    # is kept so new logs stay comparable with earlier runs.\n",
    "    model = AdversarialModel(base_model=gan,\n",
    "                             player_params=[generator.trainable_weights, discriminator.trainable_weights],\n",
    "                             player_names=[\"generator\", \"descriminator\"])\n",
    "    model.adversarial_compile(adversarial_optimizer=adversarial_optimizer,\n",
    "                              player_optimizers=[opt_g, opt_d], loss=loss)\n",
    "\n",
    "    # One fixed batch of noise (for a 10 x 10 grid), drawn once and reused by\n",
    "    # the sampler below on every call.\n",
    "    zsamples = np.random.normal(size=(10 * 10, latent_dim))\n",
    "\n",
    "    def generator_sampler():\n",
    "        # NOTE(review): assumes dim_ordering_unfix yields channels-first arrays,\n",
    "        # which the transpose turns into channels-last for plotting - confirm\n",
    "        # against image_utils.\n",
    "        xpred = dim_ordering_unfix(generator.predict(zsamples)).transpose((0, 2, 3, 1))\n",
    "        return xpred.reshape((10, 10) + xpred.shape[1:])\n",
    "\n",
    "    generator_cb = ImageGridCallback(os.path.join(path, \"epoch-{:03d}.png\"), generator_sampler, cmap=None)\n",
    "\n",
    "    xtrain, xtest = cifar10_data()\n",
    "    print(xtrain.shape)\n",
    "    y = targets(xtrain.shape[0])\n",
    "    ytest = targets(xtest.shape[0])\n",
    "\n",
    "    callbacks = [generator_cb]\n",
    "    if K.backend() == \"tensorflow\":\n",
    "        callbacks.append(TensorBoard(log_dir=os.path.join(path, 'logs'), histogram_freq=0,\n",
    "                                     write_graph=True, write_images=True))\n",
    "    # Forward the callbacks and the requested epoch count; when they are\n",
    "    # omitted, fit() falls back to its own default and no callback ever runs.\n",
    "    history = fit(model, x=xtrain, y=y, validation_data=(xtest, ytest),\n",
    "                  callbacks=callbacks, nb_epoch=nb_epoch,\n",
    "                  batch_size=32)\n",
    "\n",
    "    pd.DataFrame(history.history).to_csv(csvpath)\n",
    "\n",
    "    generator.save(os.path.join(path, 'generator.h5'))\n",
    "    # Use the `discriminator` argument (the misspelled name was undefined and\n",
    "    # raised NameError after training). The file name is left unchanged.\n",
    "    discriminator.save(os.path.join(path, 'descriminator.h5'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Running model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.5/site-packages/ipykernel_launcher.py:6: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(4096, kernel_regularizer=<keras.reg..., input_dim=100)`\n",
      "  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training: ./output/gan-cifar10/history.csv\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_3 (Dense)              (None, 4096)              413696    \n",
      "_________________________________________________________________\n",
      "batch_normalization_9 (Batch (None, 4096)              16384     \n",
      "_________________________________________________________________\n",
      "reshape_3 (Reshape)          (None, 4, 4, 256)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_17 (Conv2D)           (None, 4, 4, 128)         819328    \n",
      "_________________________________________________________________\n",
      "batch_normalization_10 (Batc (None, 4, 4, 128)         16        \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_13 (LeakyReLU)   (None, 4, 4, 128)         0         \n",
      "_________________________________________________________________\n",
      "up_sampling2d_7 (UpSampling2 (None, 8, 8, 128)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_18 (Conv2D)           (None, 8, 8, 128)         409728    \n",
      "_________________________________________________________________\n",
      "batch_normalization_11 (Batc (None, 8, 8, 128)         32        \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_14 (LeakyReLU)   (None, 8, 8, 128)         0         \n",
      "_________________________________________________________________\n",
      "up_sampling2d_8 (UpSampling2 (None, 16, 16, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv2d_19 (Conv2D)           (None, 16, 16, 64)        204864    \n",
      "_________________________________________________________________\n",
      "batch_normalization_12 (Batc (None, 16, 16, 64)        64        \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_15 (LeakyReLU)   (None, 16, 16, 64)        0         \n",
      "_________________________________________________________________\n",
      "up_sampling2d_9 (UpSampling2 (None, 32, 32, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_20 (Conv2D)           (None, 32, 32, 3)         4803      \n",
      "_________________________________________________________________\n",
      "activation_5 (Activation)    (None, 32, 32, 3)         0         \n",
      "=================================================================\n",
      "Total params: 1,868,915\n",
      "Trainable params: 1,860,667\n",
      "Non-trainable params: 8,248\n",
      "_________________________________________________________________\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_21 (Conv2D)           (None, 32, 32, 64)        4864      \n",
      "_________________________________________________________________\n",
      "max_pooling2d_7 (MaxPooling2 (None, 16, 16, 64)        0         \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_16 (LeakyReLU)   (None, 16, 16, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_22 (Conv2D)           (None, 16, 16, 128)       204928    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_8 (MaxPooling2 (None, 8, 8, 128)         0         \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_17 (LeakyReLU)   (None, 8, 8, 128)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_23 (Conv2D)           (None, 8, 8, 256)         819456    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_9 (MaxPooling2 (None, 4, 4, 256)         0         \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_18 (LeakyReLU)   (None, 4, 4, 256)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_24 (Conv2D)           (None, 4, 4, 1)           6401      \n",
      "_________________________________________________________________\n",
      "average_pooling2d_3 (Average (None, 1, 1, 1)           0         \n",
      "_________________________________________________________________\n",
      "flatten_3 (Flatten)          (None, 1)                 0         \n",
      "_________________________________________________________________\n",
      "activation_6 (Activation)    (None, 1)                 0         \n",
      "=================================================================\n",
      "Total params: 1,035,649\n",
      "Trainable params: 1,035,649\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "(50000, 32, 32, 3)\n",
      "Train on 50000 samples, validate on 10000 samples\n",
      "Epoch 1/10\n",
      "50000/50000 [==============================] - 1606s - loss: 11.4726 - generator_loss: 11.0242 - generator_yfake_loss: 4.5467 - generator_yreal_loss: 6.4356 - descriminator_loss: 0.4483 - descriminator_yfake_loss: 0.1817 - descriminator_yreal_loss: 0.2247 - val_loss: 9.8941 - val_generator_loss: 9.5857 - val_generator_yfake_loss: 5.3131 - val_generator_yreal_loss: 4.2227 - val_descriminator_loss: 0.3084 - val_descriminator_yfake_loss: 0.0141 - val_descriminator_yreal_loss: 0.2444\n",
      "Epoch 2/10\n",
      "50000/50000 [==============================] - 1600s - loss: 12.8192 - generator_loss: 12.4436 - generator_yfake_loss: 5.3174 - generator_yreal_loss: 7.0673 - descriminator_loss: 0.3757 - descriminator_yfake_loss: 0.1381 - descriminator_yreal_loss: 0.1787 - val_loss: 10.5536 - val_generator_loss: 10.0958 - val_generator_yfake_loss: 2.7463 - val_generator_yreal_loss: 7.2833 - val_descriminator_loss: 0.4578 - val_descriminator_yfake_loss: 0.2997 - val_descriminator_yreal_loss: 0.0919\n",
      "Epoch 3/10\n",
      "50000/50000 [==============================] - 1595s - loss: 11.9600 - generator_loss: 11.5703 - generator_yfake_loss: 4.6330 - generator_yreal_loss: 6.8654 - descriminator_loss: 0.3897 - descriminator_yfake_loss: 0.1380 - descriminator_yreal_loss: 0.1798 - val_loss: 9.9060 - val_generator_loss: 9.3688 - val_generator_yfake_loss: 5.3284 - val_generator_yreal_loss: 3.9631 - val_descriminator_loss: 0.5372 - val_descriminator_yfake_loss: 0.0200 - val_descriminator_yreal_loss: 0.4398\n",
      "Epoch 4/10\n",
      "50000/50000 [==============================] - 1582s - loss: 12.0365 - generator_loss: 11.6670 - generator_yfake_loss: 4.5909 - generator_yreal_loss: 6.9941 - descriminator_loss: 0.3695 - descriminator_yfake_loss: 0.1281 - descriminator_yreal_loss: 0.1595 - val_loss: 13.4516 - val_generator_loss: 12.9344 - val_generator_yfake_loss: 7.4693 - val_generator_yreal_loss: 5.3792 - val_descriminator_loss: 0.5172 - val_descriminator_yfake_loss: 0.0058 - val_descriminator_yreal_loss: 0.4254\n",
      "Epoch 5/10\n",
      "50000/50000 [==============================] - 1582s - loss: 12.5079 - generator_loss: 12.1342 - generator_yfake_loss: 4.6881 - generator_yreal_loss: 7.3555 - descriminator_loss: 0.3737 - descriminator_yfake_loss: 0.1250 - descriminator_yreal_loss: 0.1580 - val_loss: 11.3588 - val_generator_loss: 11.0235 - val_generator_yfake_loss: 4.8737 - val_generator_yreal_loss: 6.0548 - val_descriminator_loss: 0.3353 - val_descriminator_yfake_loss: 0.0454 - val_descriminator_yreal_loss: 0.1950\n",
      "Epoch 6/10\n",
      "50000/50000 [==============================] - 1585s - loss: 12.6611 - generator_loss: 12.2647 - generator_yfake_loss: 4.6718 - generator_yreal_loss: 7.4941 - descriminator_loss: 0.3964 - descriminator_yfake_loss: 0.1317 - descriminator_yreal_loss: 0.1659 - val_loss: 12.3730 - val_generator_loss: 12.0400 - val_generator_yfake_loss: 3.2166 - val_generator_yreal_loss: 8.7207 - val_descriminator_loss: 0.3330 - val_descriminator_yfake_loss: 0.1831 - val_descriminator_yreal_loss: 0.0472\n",
      "Epoch 7/10\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "50000/50000 [==============================] - 1583s - loss: 12.6415 - generator_loss: 12.2515 - generator_yfake_loss: 4.5163 - generator_yreal_loss: 7.6293 - descriminator_loss: 0.3900 - descriminator_yfake_loss: 0.1258 - descriminator_yreal_loss: 0.1582 - val_loss: 12.1586 - val_generator_loss: 11.7245 - val_generator_yfake_loss: 5.2948 - val_generator_yreal_loss: 6.3203 - val_descriminator_loss: 0.4342 - val_descriminator_yfake_loss: 0.0772 - val_descriminator_yreal_loss: 0.2475\n",
      "Epoch 8/10\n",
      "50000/50000 [==============================] - 1578s - loss: 12.2446 - generator_loss: 11.8261 - generator_yfake_loss: 4.3395 - generator_yreal_loss: 7.3740 - descriminator_loss: 0.4185 - descriminator_yfake_loss: 0.1374 - descriminator_yreal_loss: 0.1686 - val_loss: 12.4301 - val_generator_loss: 11.9971 - val_generator_yfake_loss: 5.0931 - val_generator_yreal_loss: 6.7885 - val_descriminator_loss: 0.4329 - val_descriminator_yfake_loss: 0.0487 - val_descriminator_yreal_loss: 0.2686\n",
      "Epoch 9/10\n",
      "50000/50000 [==============================] - 1578s - loss: 12.5830 - generator_loss: 12.1759 - generator_yfake_loss: 4.4437 - generator_yreal_loss: 7.6139 - descriminator_loss: 0.4071 - descriminator_yfake_loss: 0.1295 - descriminator_yreal_loss: 0.1593 - val_loss: 13.1148 - val_generator_loss: 12.7203 - val_generator_yfake_loss: 5.2060 - val_generator_yreal_loss: 7.3932 - val_descriminator_loss: 0.3945 - val_descriminator_yfake_loss: 0.0358 - val_descriminator_yreal_loss: 0.2377\n",
      "Epoch 10/10\n",
      "50000/50000 [==============================] - 1578s - loss: 12.8239 - generator_loss: 12.4150 - generator_yfake_loss: 4.4697 - generator_yreal_loss: 7.8219 - descriminator_loss: 0.4089 - descriminator_yfake_loss: 0.1281 - descriminator_yreal_loss: 0.1574 - val_loss: 12.6897 - val_generator_loss: 12.3269 - val_generator_yfake_loss: 4.2452 - val_generator_yreal_loss: 7.9559 - val_descriminator_loss: 0.3629 - val_descriminator_yfake_loss: 0.1244 - val_descriminator_yreal_loss: 0.1126\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'descriminator' is not defined",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-18-efb07bf8d861>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      8\u001b[0m             \u001b[0mgenerator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m             \u001b[0mdiscriminator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdiscriminator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m             \u001b[0mlatent_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlatent_dim\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m            )\n",
      "\u001b[0;32m<ipython-input-17-100b196a4bc4>\u001b[0m in \u001b[0;36mexample_gan\u001b[0;34m(adversarial_optimizer, path, opt_g, opt_d, nb_epoch, generator, discriminator, latent_dim, targets, loss)\u001b[0m\n\u001b[1;32m     45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m     \u001b[0mgenerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'generator.h5'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m     \u001b[0mdescriminator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'descriminator.h5'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'descriminator' is not defined"
     ],
     "output_type": "error"
    }
   ],
   "source": [
    "# Size of the noise vector; matches the generator's input_dim.\n",
    "latent_dim = 100\n",
    "\n",
    "generator = model_generator()\n",
    "discriminator = model_discriminator()\n",
    "\n",
    "# Train both players simultaneously and write all artefacts to the output folder.\n",
    "example_gan(\n",
    "    adversarial_optimizer=AdversarialOptimizerSimultaneous(),\n",
    "    path=\"./output/gan-cifar10\",\n",
    "    opt_g=Adam(1e-4, decay=1e-5),\n",
    "    opt_d=Adam(1e-3, decay=1e-5),\n",
    "    nb_epoch=100,\n",
    "    generator=generator,\n",
    "    discriminator=discriminator,\n",
    "    latent_dim=latent_dim,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    ""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3.0
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}